{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "file_extension": ".py",
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.2"
    },
    "mimetype": "text/x-python",
    "name": "python",
    "npconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": 3,
    "colab": {
      "name": "cyclegan.ipynb",
      "provenance": [],
      "private_outputs": true,
      "toc_visible": true
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "v1CUZ0dkOo_F"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "qmkj-80IHxnd",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_xnMOsbqHz61"
      },
      "source": [
        "# CycleGAN"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Ds4o1h4WHz9U"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://tensorflow.google.cn/tutorials/generative/cyclegan\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\" />View on tensorflow.google.cn</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/zh-cn/tutorials/generative/cyclegan.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/zh-cn/tutorials/generative/cyclegan.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/zh-cn/tutorials/generative/cyclegan.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TE8yKcX672WZ",
        "colab_type": "text"
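      },
      "source": [
        "Note: The TensorFlow community translated these documents. Because community translations are best-effort, there is no guarantee that they are accurate or that they reflect the latest\n",
        "[official English documentation](https://www.tensorflow.org/?hl=en). If you have suggestions for improving this translation, please submit a pull request to the\n",
        "[tensorflow/docs](https://github.com/tensorflow/docs) GitHub repository. To volunteer to write or review translations, join the\n",
        "[docs-zh-cn@tensorflow.org Google Group](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs-zh-cn)."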
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ITZuApL56Mny"
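      },
      "source": [
        "This notebook demonstrates unpaired image-to-image translation using conditional GANs, as described in [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593), also known as CycleGAN. The paper proposes a method that can capture the characteristics of one image domain and work out how to translate them into another image domain, without any paired training examples.\n",
        "\n",
        "This notebook assumes you are familiar with Pix2Pix, which you can learn about in the [Pix2Pix tutorial](https://tensorflow.google.cn/tutorials/generative/pix2pix). The code for CycleGAN is similar; the main differences are an additional loss function and the use of unpaired training data.\n",
        "\n",
        "CycleGAN uses a cycle consistency loss so that training does not need paired data. In other words, it can translate from one domain to another without a one-to-one mapping between the source and target domains.\n",
        "\n",
        "This opens up many interesting tasks, such as photo enhancement, image colorization, and style transfer. All you need is a source dataset and a target dataset (each is simply a directory of images).\n",
        "\n",
        "![Output Image 1](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/horse2zebra_1.png?raw=1)\n",
        "![Output Image 2](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/horse2zebra_2.png?raw=1)"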
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "e1_Y75QXJS6h"
      },
      "source": [
        "## Set up the input pipeline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5fGHWOKPX4ta"
      },
      "source": [
        "Install the [tensorflow_examples](https://github.com/tensorflow/examples) package, which provides the generator and the discriminator to import."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "bJ1ROiQxJ-vY",
        "colab": {}
      },
      "source": [
        "!pip install git+https://github.com/tensorflow/examples.git"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "lhSsUx9Nyb3t",
        "colab": {}
      },
      "source": [
        "try:\n",
        "  # %tensorflow_version only exists in Colab.\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "  pass\n",
        "import tensorflow as tf"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "YfIk2es3hJEd",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "import tensorflow_datasets as tfds\n",
        "from tensorflow_examples.models.pix2pix import pix2pix\n",
        "\n",
        "import os\n",
        "import time\n",
        "import matplotlib.pyplot as plt\n",
        "from IPython.display import clear_output\n",
        "\n",
        "tfds.disable_progress_bar()\n",
        "AUTOTUNE = tf.data.experimental.AUTOTUNE"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "iYn4MdZnKCey"
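      },
      "source": [
        "## Input pipeline\n",
        "\n",
        "This tutorial trains a model to translate images of horses into images of zebras. You can find this dataset and similar ones [here](https://tensorflow.google.cn/datasets/datasets#cycle_gan).\n",
        "\n",
        "As mentioned in the [paper](https://arxiv.org/abs/1703.10593), apply random jittering and mirroring to the training dataset. These are image augmentation techniques that help avoid overfitting.\n",
        "\n",
        "This is similar to what was done in [pix2pix](https://tensorflow.google.cn/tutorials/generative/pix2pix#load_the_dataset).\n",
        "\n",
        "* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`.\n",
        "* In random mirroring, the image is randomly flipped horizontally, i.e. left to right."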
      ]
    },
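    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The crop and mirror steps listed above can be sketched in plain Python on a small nested-list image. This is only an illustration under stated assumptions: the helper names `random_crop_2d` and `mirror` are hypothetical, the toy image is 6 x 6 rather than 286 x 286, and the notebook itself relies on `tf.image` ops instead.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "def random_crop_2d(image, out_h, out_w, rng):\n",
        "    # Pick a top-left corner so that the crop window stays inside the image.\n",
        "    in_h, in_w = len(image), len(image[0])\n",
        "    top = rng.randint(0, in_h - out_h)\n",
        "    left = rng.randint(0, in_w - out_w)\n",
        "    return [row[left:left + out_w] for row in image[top:top + out_h]]\n",
        "\n",
        "def mirror(image):\n",
        "    # Flip left to right by reversing every row.\n",
        "    return [row[::-1] for row in image]\n",
        "\n",
        "rng = random.Random(0)\n",
        "# A toy 6 x 6 image whose pixel value encodes its (row, column) position.\n",
        "image = [[10 * r + c for c in range(6)] for r in range(6)]\n",
        "crop = random_crop_2d(image, 4, 4, rng)\n",
        "if rng.random() < 0.5:\n",
        "    crop = mirror(crop)\n",
        "\n",
        "print(len(crop), len(crop[0]))\n",
        "# Number of possible crop offsets per axis for the sizes used in this tutorial.\n",
        "offsets_per_axis = 286 - 256 + 1\n",
        "print(offsets_per_axis)\n",
        "```\n",
        "\n",
        "With the tutorial's sizes, each axis has 31 possible crop offsets, so together with mirroring a single source image can yield many slightly different training views."
      ]
    },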
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "iuGVPOo7Cce0",
        "colab": {}
      },
      "source": [
        "dataset, metadata = tfds.load('cycle_gan/horse2zebra',\n",
        "                              with_info=True, as_supervised=True)\n",
        "\n",
        "train_horses, train_zebras = dataset['trainA'], dataset['trainB']\n",
        "test_horses, test_zebras = dataset['testA'], dataset['testB']"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "2CbTEt448b4R",
        "colab": {}
      },
      "source": [
        "BUFFER_SIZE = 1000\n",
        "BATCH_SIZE = 1\n",
        "IMG_WIDTH = 256\n",
        "IMG_HEIGHT = 256"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Yn3IwqhiIszt",
        "colab": {}
      },
      "source": [
        "def random_crop(image):\n",
        "  cropped_image = tf.image.random_crop(\n",
        "      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])\n",
        "\n",
        "  return cropped_image"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "muhR2cgbLKWW",
        "colab": {}
      },
      "source": [
        "# Normalize the image to the range [-1, 1].\n",
        "def normalize(image):\n",
        "  image = tf.cast(image, tf.float32)\n",
        "  image = (image / 127.5) - 1\n",
        "  return image"
      ],
      "execution_count": 0,
      "outputs": []
    },
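    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The plotting cells in this notebook compute `image * 0.5 + 0.5` before calling `plt.imshow`, which undoes the normalization above. A minimal per-pixel sketch of the round trip, using the hypothetical helper names `normalize_pixel` and `denormalize_pixel`:\n",
        "\n",
        "```python\n",
        "def normalize_pixel(value):\n",
        "    # Map an 8-bit intensity in [0, 255] to [-1, 1].\n",
        "    return value / 127.5 - 1\n",
        "\n",
        "def denormalize_pixel(value):\n",
        "    # Map [-1, 1] to [0, 1], the range used for float images when plotting.\n",
        "    return value * 0.5 + 0.5\n",
        "\n",
        "for raw in (0, 127.5, 255):\n",
        "    scaled = normalize_pixel(raw)\n",
        "    print(raw, scaled, denormalize_pixel(scaled))\n",
        "```\n",
        "\n",
        "The endpoints 0 and 255 map to -1 and 1, and back to 0 and 1 for display."
      ]
    },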
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "fVQOjcPVLrUc",
        "colab": {}
      },
      "source": [
        "def random_jitter(image):\n",
        "  # Resize to 286 x 286 x 3.\n",
        "  image = tf.image.resize(image, [286, 286],\n",
        "                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
        "\n",
        "  # Randomly crop to 256 x 256 x 3.\n",
        "  image = random_crop(image)\n",
        "\n",
        "  # Random mirroring.\n",
        "  image = tf.image.random_flip_left_right(image)\n",
        "\n",
        "  return image"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "tyaP4hLJ8b4W",
        "colab": {}
      },
      "source": [
        "def preprocess_image_train(image, label):\n",
        "  image = random_jitter(image)\n",
        "  image = normalize(image)\n",
        "  return image"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "VB3Z6D_zKSru",
        "colab": {}
      },
      "source": [
        "def preprocess_image_test(image, label):\n",
        "  image = normalize(image)\n",
        "  return image"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "RsajGXxd5JkZ",
        "colab": {}
      },
      "source": [
        "# Cache the training images before the random augmentation so that each\n",
        "# epoch draws a fresh jitter instead of replaying the first one.\n",
        "train_horses = train_horses.cache().map(\n",
        "    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(\n",
        "    BUFFER_SIZE).batch(BATCH_SIZE)\n",
        "\n",
        "train_zebras = train_zebras.cache().map(\n",
        "    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(\n",
        "    BUFFER_SIZE).batch(BATCH_SIZE)\n",
        "\n",
        "# The test preprocessing is deterministic, so caching after map is safe.\n",
        "test_horses = test_horses.map(\n",
        "    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(\n",
        "    BUFFER_SIZE).batch(BATCH_SIZE)\n",
        "\n",
        "test_zebras = test_zebras.map(\n",
        "    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(\n",
        "    BUFFER_SIZE).batch(BATCH_SIZE)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "e3MhJ3zVLPan",
        "colab": {}
      },
      "source": [
        "sample_horse = next(iter(train_horses))\n",
        "sample_zebra = next(iter(train_zebras))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "4pOYjMk_KfIB",
        "colab": {}
      },
      "source": [
        "plt.subplot(121)\n",
        "plt.title('Horse')\n",
        "plt.imshow(sample_horse[0] * 0.5 + 0.5)\n",
        "\n",
        "plt.subplot(122)\n",
        "plt.title('Horse with random jitter')\n",
        "plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "0KJyB9ENLb2y",
        "colab": {}
      },
      "source": [
        "plt.subplot(121)\n",
        "plt.title('Zebra')\n",
        "plt.imshow(sample_zebra[0] * 0.5 + 0.5)\n",
        "\n",
        "plt.subplot(122)\n",
        "plt.title('Zebra with random jitter')\n",
        "plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hvX8sKsfMaio"
      },
      "source": [
        "## Import and reuse the Pix2Pix models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cGrL73uCd-_M"
      },
      "source": [
        "Import the generator and the discriminator used in [Pix2Pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) via the installed [tensorflow_examples](https://github.com/tensorflow/examples) package.\n",
        "\n",
        "The model architecture used in this tutorial is very similar to the one used in [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py). Some of the differences are:\n",
        "\n",
        "* CycleGAN uses [instance normalization](https://arxiv.org/abs/1607.08022) instead of [batch normalization](https://arxiv.org/abs/1502.03167).\n",
        "* The [CycleGAN paper](https://arxiv.org/abs/1703.10593) uses a modified `resnet`-based generator. For simplicity, this tutorial uses a modified `unet` generator.\n",
        "\n",
        "Two generators (`G` and `F`) and two discriminators (`D_X` and `D_Y`) are trained here.\n",
        "\n",
        "* Generator `G` learns to transform image `X` into image `Y`. $(G: X \\rightarrow Y)$\n",
        "* Generator `F` learns to transform image `Y` into image `X`. $(F: Y \\rightarrow X)$\n",
        "* Discriminator `D_X` learns to differentiate between image `X` and generated image `X` (`F(Y)`).\n",
        "* Discriminator `D_Y` learns to differentiate between image `Y` and generated image `Y` (`G(X)`).\n",
        "\n",
        "![CycleGAN model](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/cyclegan_model.png?raw=1)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "8ju9Wyw87MRW",
        "colab": {}
      },
      "source": [
        "OUTPUT_CHANNELS = 3\n",
        "\n",
        "generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')\n",
        "generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')\n",
        "\n",
        "discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)\n",
        "discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "wDaGZ3WpZUyw",
        "colab": {}
      },
      "source": [
        "to_zebra = generator_g(sample_horse)\n",
        "to_horse = generator_f(sample_zebra)\n",
        "plt.figure(figsize=(8, 8))\n",
        "contrast = 8\n",
        "\n",
        "imgs = [sample_horse, to_zebra, sample_zebra, to_horse]\n",
        "title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']\n",
        "\n",
        "for i in range(len(imgs)):\n",
        "  plt.subplot(2, 2, i+1)\n",
        "  plt.title(title[i])\n",
        "  if i % 2 == 0:\n",
        "    plt.imshow(imgs[i][0] * 0.5 + 0.5)\n",
        "  else:\n",
        "    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "O5MhJmxyZiy9",
        "colab": {}
      },
      "source": [
        "plt.figure(figsize=(8, 8))\n",
        "\n",
        "plt.subplot(121)\n",
        "plt.title('Is a real zebra?')\n",
        "plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')\n",
        "\n",
        "plt.subplot(122)\n",
        "plt.title('Is a real horse?')\n",
        "plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')\n",
        "\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "0FMYgY_mPfTi"
      },
      "source": [
        "## Loss functions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JRqt02lupRn8"
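      },
      "source": [
        "In CycleGAN there is no paired data to train on, so there is no guarantee that an input `x` and a target `y` form a meaningful pair during training. To force the network to learn the correct mapping, the authors propose the cycle consistency loss.\n",
        "\n",
        "The discriminator loss and the generator loss are similar to the ones used in [pix2pix](https://tensorflow.google.cn/tutorials/generative/pix2pix#define_the_loss_functions_and_the_optimizer)."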
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "cyhxTuvJyIHV",
        "colab": {}
      },
      "source": [
        "LAMBDA = 10"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Q1Xbz5OaLj5C",
        "colab": {}
      },
      "source": [
        "loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "wkMNfBWlT-PV",
        "colab": {}
      },
      "source": [
        "def discriminator_loss(real, generated):\n",
        "  real_loss = loss_obj(tf.ones_like(real), real)\n",
        "\n",
        "  generated_loss = loss_obj(tf.zeros_like(generated), generated)\n",
        "\n",
        "  total_disc_loss = real_loss + generated_loss\n",
        "\n",
        "  return total_disc_loss * 0.5"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "90BIcCKcDMxz",
        "colab": {}
      },
      "source": [
        "def generator_loss(generated):\n",
        "  return loss_obj(tf.ones_like(generated), generated)"
      ],
      "execution_count": 0,
      "outputs": []
    },
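    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To see what the two adversarial losses above compute, here is a plain-Python sketch for a single discriminator logit. It is an illustration only: the function names are hypothetical, and the Keras `BinaryCrossentropy(from_logits=True)` loss used in this notebook additionally averages over every element of the discriminator output.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def bce_from_logit(label, logit):\n",
        "    # Numerically stable sigmoid cross-entropy for one logit z and label y:\n",
        "    # max(z, 0) - z * y + log(1 + exp(-|z|))\n",
        "    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))\n",
        "\n",
        "def discriminator_loss_scalar(real_logit, fake_logit):\n",
        "    # Real outputs are compared with label 1, generated outputs with label 0.\n",
        "    return 0.5 * (bce_from_logit(1.0, real_logit) + bce_from_logit(0.0, fake_logit))\n",
        "\n",
        "def generator_loss_scalar(fake_logit):\n",
        "    # The generator wants the discriminator to label its output as real (1).\n",
        "    return bce_from_logit(1.0, fake_logit)\n",
        "\n",
        "# An undecided discriminator (logit 0) gives ln(2) for every term.\n",
        "print(discriminator_loss_scalar(0.0, 0.0), generator_loss_scalar(0.0))\n",
        "# A confident, correct discriminator has a low loss; the generator loss is high.\n",
        "print(discriminator_loss_scalar(5.0, -5.0), generator_loss_scalar(-5.0))\n",
        "```\n",
        "\n",
        "The two networks pull the same logit in opposite directions: whatever lowers the discriminator loss on a generated image raises the generator loss."
      ]
    },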
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5iIWQzVF7f9e"
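      },
      "source": [
        "Cycle consistency means the result should be close to the original input. For example, if you translate a sentence from English to French and then translate it back from French to English, the resulting sentence should be the same as the original sentence.\n",
        "\n",
        "In the cycle consistency loss,\n",
        "\n",
        "* Image $X$ is passed through generator $G$, which yields the generated image $\\hat{Y}$.\n",
        "* The generated image $\\hat{Y}$ is passed through generator $F$, which yields the cycled image $\\hat{X}$.\n",
        "* The mean absolute error is calculated between $X$ and $\\hat{X}$.\n",
        "\n",
        "$$forward\\ cycle\\ consistency\\ loss: X \\rightarrow G(X) \\rightarrow F(G(X)) \\sim \\hat{X}$$\n",
        "\n",
        "$$backward\\ cycle\\ consistency\\ loss: Y \\rightarrow F(Y) \\rightarrow G(F(Y)) \\sim \\hat{Y}$$\n",
        "\n",
        "\n",
        "![Cycle loss](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/cycle_loss.png?raw=1)"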
      ]
    },
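    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "A toy worked example of the steps above, in plain Python. The stand-in generators `g`, `f_exact`, and `f_biased` are hypothetical functions on flat lists of pixel values, not neural networks; they only show how the loss reacts to an imperfect round trip.\n",
        "\n",
        "```python\n",
        "LAMBDA = 10  # the same cycle-loss weight this notebook uses\n",
        "\n",
        "def mean_abs_error(a, b):\n",
        "    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)\n",
        "\n",
        "def cycle_loss(real, cycled):\n",
        "    return LAMBDA * mean_abs_error(real, cycled)\n",
        "\n",
        "def g(x):\n",
        "    # Stand-in for generator G: X -> Y.\n",
        "    return [2.0 * v for v in x]\n",
        "\n",
        "def f_exact(y):\n",
        "    # Stand-in for generator F: Y -> X, an exact inverse of g.\n",
        "    return [v / 2.0 for v in y]\n",
        "\n",
        "def f_biased(y):\n",
        "    # An imperfect inverse of g that shifts every pixel by 0.1.\n",
        "    return [v / 2.0 + 0.1 for v in y]\n",
        "\n",
        "x = [0.0, 0.5, -0.25, 1.0]\n",
        "print(cycle_loss(x, f_exact(g(x))))   # exact round trip\n",
        "print(cycle_loss(x, f_biased(g(x))))  # every pixel is off by about 0.1\n",
        "```\n",
        "\n",
        "An exact round trip costs nothing, while a uniform error of about 0.1 per pixel costs about `LAMBDA * 0.1`."
      ]
    },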
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "NMpVGj_sW6Vo",
        "colab": {}
      },
      "source": [
        "def calc_cycle_loss(real_image, cycled_image):\n",
        "  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))\n",
        "  \n",
        "  return LAMBDA * loss1"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "U-tJL-fX0Mq7"
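      },
      "source": [
        "As shown above, generator $G$ is responsible for translating image $X$ to image $Y$. Identity loss says that if you feed image $Y$ to generator $G$, it should yield the real image $Y$ or something close to it.\n",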
        "\n",
        "$$Identity\\ loss = |G(Y) - Y| + |F(X) - X|$$"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "05ywEH680Aud",
        "colab": {}
      },
      "source": [
        "def identity_loss(real_image, same_image):\n",
        "  loss = tf.reduce_mean(tf.abs(real_image - same_image))\n",
        "  return LAMBDA * 0.5 * loss"
      ],
      "execution_count": 0,
      "outputs": []
    },
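    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Putting the pieces together, the generator objectives that `train_step` assembles later in this notebook can be written as follows. This restates the notebook's code rather than the paper's exact formulation. Notation assumed here: $\\lambda$ is `LAMBDA = 10`, $\\mathrm{MAE}$ is the mean absolute error, and $\\mathcal{L}_{adv}$ is the adversarial generator loss (binary cross-entropy against a target of ones).\n",
        "\n",
        "$$\\mathcal{L}_{cyc} = \\lambda\\,\\mathrm{MAE}(X, F(G(X))) + \\lambda\\,\\mathrm{MAE}(Y, G(F(Y)))$$\n",
        "\n",
        "$$\\mathcal{L}_{G} = \\mathcal{L}_{adv}(D_Y(G(X))) + \\mathcal{L}_{cyc} + 0.5\\,\\lambda\\,\\mathrm{MAE}(Y, G(Y))$$\n",
        "\n",
        "$$\\mathcal{L}_{F} = \\mathcal{L}_{adv}(D_X(F(Y))) + \\mathcal{L}_{cyc} + 0.5\\,\\lambda\\,\\mathrm{MAE}(X, F(X))$$\n",
        "\n",
        "Both generators share the same total cycle term, while each has its own adversarial and identity term."
      ]
    },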
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "G-vjRM7IffTT"
      },
      "source": [
        "Initialize the optimizers for all the generators and the discriminators."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "iWCn_PVdEJZ7",
        "colab": {}
      },
      "source": [
        "generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n",
        "generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n",
        "\n",
        "discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)\n",
        "discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "aKUZnDiqQrAh"
      },
      "source": [
        "## Checkpoints"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "WJnftd5sQsv6",
        "colab": {}
      },
      "source": [
        "checkpoint_path = \"./checkpoints/train\"\n",
        "\n",
        "ckpt = tf.train.Checkpoint(generator_g=generator_g,\n",
        "                           generator_f=generator_f,\n",
        "                           discriminator_x=discriminator_x,\n",
        "                           discriminator_y=discriminator_y,\n",
        "                           generator_g_optimizer=generator_g_optimizer,\n",
        "                           generator_f_optimizer=generator_f_optimizer,\n",
        "                           discriminator_x_optimizer=discriminator_x_optimizer,\n",
        "                           discriminator_y_optimizer=discriminator_y_optimizer)\n",
        "\n",
        "ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)\n",
        "\n",
        "# If a checkpoint exists, restore the latest checkpoint.\n",
        "if ckpt_manager.latest_checkpoint:\n",
        "  ckpt.restore(ckpt_manager.latest_checkpoint)\n",
        "  print ('Latest checkpoint restored!!')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Rw1fkAczTQYh"
      },
      "source": [
        "## Training\n",
        "\n",
        "Note: This example model is trained for fewer epochs (40) than the paper (200) to keep the training time reasonable for this tutorial. Its predictions may be less accurate."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "NS2GWywBbAWo",
        "colab": {}
      },
      "source": [
        "EPOCHS = 40"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "RmdVsmvhPxyy",
        "colab": {}
      },
      "source": [
        "def generate_images(model, test_input):\n",
        "  prediction = model(test_input)\n",
        "    \n",
        "  plt.figure(figsize=(12, 12))\n",
        "\n",
        "  display_list = [test_input[0], prediction[0]]\n",
        "  title = ['Input Image', 'Predicted Image']\n",
        "\n",
        "  for i in range(2):\n",
        "    plt.subplot(1, 2, i+1)\n",
        "    plt.title(title[i])\n",
        "    # Rescale the pixel values to [0, 1] for plotting.\n",
        "    plt.imshow(display_list[i] * 0.5 + 0.5)\n",
        "    plt.axis('off')\n",
        "  plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kE47ERn5fyLC"
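      },
      "source": [
        "Even though the training loop looks complicated, it consists of four basic steps:\n",
        "\n",
        "* Get the predictions.\n",
        "* Calculate the loss.\n",
        "* Calculate the gradients using backpropagation.\n",
        "* Apply the gradients to the optimizer."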
      ]
    },
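    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The four steps above can be seen in isolation on a one-parameter model. This is a deliberately simplified sketch, not the notebook's training step: it assumes a hand-derived gradient and plain gradient descent as stand-ins for `tf.GradientTape` and the Adam optimizers, and the variable names are made up for illustration.\n",
        "\n",
        "```python\n",
        "# Fit y = w * x to a single data point with plain gradient descent.\n",
        "x, target = 2.0, 6.0\n",
        "w = 0.0\n",
        "learning_rate = 0.1\n",
        "\n",
        "for step in range(50):\n",
        "    prediction = w * x                          # 1. get the prediction\n",
        "    loss = (prediction - target) ** 2           # 2. calculate the loss\n",
        "    grad_w = 2.0 * (prediction - target) * x    # 3. gradient of the loss w.r.t. w\n",
        "    w = w - learning_rate * grad_w              # 4. apply the gradient\n",
        "\n",
        "print(w)\n",
        "```\n",
        "\n",
        "Here `w` converges towards 3, the value for which `w * x` equals the target. The CycleGAN training step repeats the same pattern for four networks at once."
      ]
    },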
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "KBKUV2sKXDbY",
        "colab": {}
      },
      "source": [
        "@tf.function\n",
        "def train_step(real_x, real_y):\n",
        "  # persistent is set to True because the tape is used more than\n",
        "  # once to calculate the gradients.\n",
        "  with tf.GradientTape(persistent=True) as tape:\n",
        "    # Generator G translates X -> Y.\n",
        "    # Generator F translates Y -> X.\n",
        "    \n",
        "    fake_y = generator_g(real_x, training=True)\n",
        "    cycled_x = generator_f(fake_y, training=True)\n",
        "\n",
        "    fake_x = generator_f(real_y, training=True)\n",
        "    cycled_y = generator_g(fake_x, training=True)\n",
        "\n",
        "    # same_x and same_y are used for the identity loss.\n",
        "    same_x = generator_f(real_x, training=True)\n",
        "    same_y = generator_g(real_y, training=True)\n",
        "\n",
        "    disc_real_x = discriminator_x(real_x, training=True)\n",
        "    disc_real_y = discriminator_y(real_y, training=True)\n",
        "\n",
        "    disc_fake_x = discriminator_x(fake_x, training=True)\n",
        "    disc_fake_y = discriminator_y(fake_y, training=True)\n",
        "\n",
        "    # Calculate the losses.\n",
        "    gen_g_loss = generator_loss(disc_fake_y)\n",
        "    gen_f_loss = generator_loss(disc_fake_x)\n",
        "    \n",
        "    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)\n",
        "    \n",
        "    # Total generator loss = adversarial loss + cycle loss + identity loss.\n",
        "    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)\n",
        "    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)\n",
        "\n",
        "    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)\n",
        "    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)\n",
        "  \n",
        "  # Calculate the gradients for the generators and the discriminators.\n",
        "  generator_g_gradients = tape.gradient(total_gen_g_loss, \n",
        "                                        generator_g.trainable_variables)\n",
        "  generator_f_gradients = tape.gradient(total_gen_f_loss, \n",
        "                                        generator_f.trainable_variables)\n",
        "  \n",
        "  discriminator_x_gradients = tape.gradient(disc_x_loss, \n",
        "                                            discriminator_x.trainable_variables)\n",
        "  discriminator_y_gradients = tape.gradient(disc_y_loss, \n",
        "                                            discriminator_y.trainable_variables)\n",
        "  \n",
        "  # Apply the gradients to the optimizers.\n",
        "  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, \n",
        "                                            generator_g.trainable_variables))\n",
        "\n",
        "  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, \n",
        "                                            generator_f.trainable_variables))\n",
        "  \n",
        "  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,\n",
        "                                                discriminator_x.trainable_variables))\n",
        "  \n",
        "  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,\n",
        "                                                discriminator_y.trainable_variables))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "2M7LmLtGEMQJ",
        "colab": {}
      },
      "source": [
        "for epoch in range(EPOCHS):\n",
        "  start = time.time()\n",
        "\n",
        "  n = 0\n",
        "  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):\n",
        "    train_step(image_x, image_y)\n",
        "    if n % 10 == 0:\n",
        "      print ('.', end='')\n",
        "    n+=1\n",
        "\n",
        "  clear_output(wait=True)\n",
        "  # Use a consistent image (sample_horse) so that the progress of the model\n",
        "  # is clearly visible.\n",
        "  generate_images(generator_g, sample_horse)\n",
        "\n",
        "  if (epoch + 1) % 5 == 0:\n",
        "    ckpt_save_path = ckpt_manager.save()\n",
        "    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,\n",
        "                                                         ckpt_save_path))\n",
        "\n",
        "  print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n",
        "                                                      time.time()-start))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1RGysMU_BZhx"
      },
      "source": [
        "## Generate using the test dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "KUgSnmy2nqSP",
        "colab": {}
      },
      "source": [
        "# Run the trained model on the test dataset.\n",
        "for inp in test_horses.take(5):\n",
        "  generate_images(generator_g, inp)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ABGiHY6fE02b"
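      },
      "source": [
        "## Next steps\n",
        "\n",
        "This tutorial has shown how to implement CycleGAN starting from the generator and the discriminator implemented in the [Pix2Pix](https://tensorflow.google.cn/tutorials/generative/pix2pix) tutorial. As a next step, you could try a different dataset from [TensorFlow Datasets](https://tensorflow.google.cn/datasets/datasets#cycle_gan).\n",
        "\n",
        "You could also train for more epochs to improve the results, or implement the modified ResNet generator used in the [paper](https://arxiv.org/abs/1703.10593) instead of the U-Net generator used here."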
      ]
    }
  ]
}